{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Example prediction on a single test image\n",
    "\n",
    "This notebook gives example code to make a single disparity prediction for one test image.\n",
    "\n",
    "The file `test_simple.py` shows a more complete version of this code, which additionally:\n",
    "- Can run on GPU or CPU (this notebook only runs on CPU)\n",
    "- Can predict for a whole folder of images, not just a single image\n",
    "- Saves predictions to `.npy` files and disparity images."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import PIL.Image as pil\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "import networks\n",
    "from utils import download_model_if_doesnt_exist"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Setting up network and loading weights"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "model_name = \"mono_640x192\"\n",
    "\n",
    "download_model_if_doesnt_exist(model_name)\n",
    "encoder_path = os.path.join(\"models\", model_name, \"encoder.pth\")\n",
    "depth_decoder_path = os.path.join(\"models\", model_name, \"depth.pth\")\n",
    "\n",
    "# LOADING PRETRAINED MODEL\n",
    "encoder = networks.ResnetEncoder(18, False)\n",
    "depth_decoder = networks.DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))\n",
    "\n",
    "loaded_dict_enc = torch.load(encoder_path, map_location='cpu')\n",
    "filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()}\n",
    "encoder.load_state_dict(filtered_dict_enc)\n",
    "\n",
    "loaded_dict = torch.load(depth_decoder_path, map_location='cpu')\n",
    "depth_decoder.load_state_dict(loaded_dict)\n",
    "\n",
    "encoder.eval()\n",
    "depth_decoder.eval();"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "-> Downloading pretrained model to models/mono_640x192.zip\n",
      "   Unzipping model...\n",
      "   Model unzipped to models/mono_640x192\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading the test image and preprocessing"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "image_path = \"assets/test_image.jpg\"\n",
    "\n",
    "input_image = pil.open(image_path).convert('RGB')\n",
    "original_width, original_height = input_image.size\n",
    "\n",
    "feed_height = loaded_dict_enc['height']\n",
    "feed_width = loaded_dict_enc['width']\n",
    "input_image_resized = input_image.resize((feed_width, feed_height), pil.LANCZOS)\n",
    "\n",
    "input_image_pytorch = transforms.ToTensor()(input_image_resized).unsqueeze(0)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prediction using the PyTorch model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "with torch.no_grad():\n",
    "    features = encoder(input_image_pytorch)\n",
    "    outputs = depth_decoder(features)\n",
    "\n",
    "disp = outputs['disp0']"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Plotting"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "disp_resized = torch.nn.functional.interpolate(disp,\n",
    "    (original_height, original_width), mode=\"bilinear\", align_corners=False)\n",
    "\n",
    "# Saving colormapped depth image\n",
    "disp_resized_np = disp_resized.squeeze().cpu().numpy()\n",
    "vmax = np.percentile(disp_resized_np, 95)\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(211)\n",
    "plt.imshow(input_image)\n",
    "plt.title(\"Input\", fontsize=22)\n",
    "plt.axis('off')\n",
    "\n",
    "plt.subplot(212)\n",
    "plt.imshow(disp_resized_np, cmap='magma', vmax=vmax)\n",
    "plt.title(\"Disparity prediction\", fontsize=22)\n",
    "plt.axis('off');"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}